#include "EdgeTracer.hpp"

#include <cmath>

namespace zzz{

// Creates a tracer with no seed point and the default click-snapping radius.
EdgeTracer::EdgeTracer()
{
  optrange_=5;   // OptimizeClick() scans an 11x11 window (+/-5 pixels) by default
  start_=-1;     // -1 means "no start point set yet"
}

// Intentionally empty: nothing is released explicitly here.
EdgeTracer::~EdgeTracer()
{
}


//link direction
//3 2 1
//7 p 0
//6 5 4

//Only links 1, 2, 3 and 7 are computed directly; the opposite links (6, 5, 4 and 0) of the neighbouring pixels receive the same value.
// Builds the per-pixel link-cost table from a grayscale image.
//
// For every pixel p, links_.At(p)[d] ends up holding the cost of stepping from p to
// its neighbour in direction d (3 2 1 / 7 p 0 / 6 5 4):
//   axial links    (0,2,5,7): cost = maxlink_ - D
//   diagonal links (1,3,4,6): cost = (maxlink_ - D) * sqrt(2)
// where D >= 0 is the intensity-difference magnitude across the link and maxlink_ is
// the largest D found. Strong edges therefore become cheap paths.
// Links whose stencil would read outside the image (and links leaving the image)
// keep D = 0, i.e. they get the maximum cost.
void EdgeTracer::Prepare(const Imagef &ori)
{
  const float SQRT2=Sqrt(2.0f);
  links_.SetSize(ori.Size());
  memset(links_.Data(),0,sizeof(Vector<8,float>)*links_.size());
  // Every D below is a magnitude (>= 0) and uncomputed links stay 0, so 0 is the
  // correct lower bound. Starting at -MAX_FLOAT left maxlink_ there for images with a
  // single row or column (no link is computed), producing negative / infinite costs.
  maxlink_=0.0f;
  for (zuint r=0; r<ori.Rows(); r++) for (zuint c=0; c<ori.Cols(); c++)
  {
    if (r!=0 && c!=ori.Cols()-1) 
    {
      //link(1)=|p(2)-p(0)|/sqrt(2)
      // Magnitude, not signed difference: both edge polarities must be equally
      // attractive (the colour overloads use a magnitude as well).
      float link=std::fabs(ori.At(r-1,c)-ori.At(r,c+1))/SQRT2;
      links_.At(r,c)[1]=links_.At(Vector2ui(r-1,c+1))[6]=link;
      if (maxlink_<link) maxlink_=link;
    }
    if (r!=0 && c!=0 && c!=ori.Cols()-1) 
    {
      //link(2)=|p(1)/2+p(0)-p(3)/2-p(7)|/2
      float link=std::fabs(ori.At(r-1,c+1)/2+ori.At(r,c+1)-ori.At(r-1,c-1)/2-ori.At(r,c-1))/2;
      links_.At(r,c)[2]=links_.At(r-1,c)[5]=link;
      if (maxlink_<link) maxlink_=link;
    }
    if (r!=0 && c!=0) 
    {
      //link(3)=|p(2)-p(7)|/sqrt(2)
      float link=std::fabs(ori.At(r-1,c)-ori.At(r,c-1))/SQRT2;
      links_.At(r,c)[3]=links_.At(r-1,c-1)[4]=link;
      if (maxlink_<link) maxlink_=link;
    }
    if (c!=0 && r!=0 && r!=ori.Rows()-1) 
    {
      //link(7)=|p(3)/2+p(2)-p(6)/2-p(5)|/2
      float link=std::fabs(ori.At(r-1,c-1)/2+ori.At(r-1,c)-ori.At(r+1,c-1)/2-ori.At(r+1,c))/2;
      links_.At(Vector2ui(r,c))[7]=links_.At(Vector2ui(r,c-1))[0]=link;
      if (maxlink_<link) maxlink_=link;
    }
  }
  // Turn strengths into costs; diagonal links are longer, so they are scaled by sqrt(2).
  for (zuint i=0; i<links_.size(); i++) 
  {
    links_.at(i)[0]=maxlink_-links_.at(i)[0];
    links_.at(i)[1]=(maxlink_-links_.at(i)[1])*SQRT2;
    links_.at(i)[2]=maxlink_-links_.at(i)[2];
    links_.at(i)[3]=(maxlink_-links_.at(i)[3])*SQRT2;
    links_.at(i)[4]=(maxlink_-links_.at(i)[4])*SQRT2;
    links_.at(i)[5]=maxlink_-links_.at(i)[5];
    links_.at(i)[6]=(maxlink_-links_.at(i)[6])*SQRT2;
    links_.at(i)[7]=maxlink_-links_.at(i)[7];
  }

  afterPrepare();
}

//Only links 1, 2, 3 and 7 are computed directly; the opposite links (6, 5, 4 and 0) of the neighbouring pixels receive the same value.
// Builds the per-pixel link-cost table from a 3-channel image.
//
// Same layout as the grayscale overload: links_.At(p)[d] is the cost of stepping from
// p to its neighbour in direction d (3 2 1 / 7 p 0 / 6 5 4), equal to
// (maxlink_ - D) for axial links and (maxlink_ - D)*sqrt(2) for diagonal links.
// Here D is the RMS over the three channels of the colour difference across the link,
// sqrt(|diff|^2 / 3). Links whose stencil would leave the image keep D = 0.
void EdgeTracer::Prepare(const Image3f &ori)
{
  const float SQRT2=Sqrt(2.0f);
  links_.SetSize(ori.Size());
  memset(links_.Data(),0,sizeof(Vector<8,float>)*links_.size());
  // Every D is >= 0 and uncomputed links stay 0, so 0 is the correct lower bound.
  // Starting at -MAX_FLOAT left maxlink_ there for images with a single row or column
  // (no link is computed), producing negative / infinite costs.
  maxlink_=0.0f;
  for (zuint r=0; r<ori.Rows(); r++) for (zuint c=0; c<ori.Cols(); c++)
  {
    if (r!=0 && c!=ori.Cols()-1) 
    {
      //link(1)=(p(2)-p(0))/sqrt(2)
      Vector3f diff=(ori.At(r-1,c)-ori.At(r,c+1))/SQRT2;
      float link=sqrt(diff.LenSqr()/3);
      links_.At(Vector2ui(r,c))[1]=links_.At(Vector2ui(r-1,c+1))[6]=link;
      if (maxlink_<link) maxlink_=link;
    }
    if (r!=0 && c!=0 && c!=ori.Cols()-1) 
    {
      //link(2)=(p(1)/2+p(0)-p(3)/2-p(7))/2
      Vector3f diff=(ori.At(r-1,c+1)/2+ori.At(r,c+1)-ori.At(r-1,c-1)/2-ori.At(r,c-1))/2;
      float link=sqrt(diff.LenSqr()/3);
      links_.At(Vector2ui(r,c))[2]=links_.At(Vector2ui(r-1,c))[5]=link;
      if (maxlink_<link) maxlink_=link;
    }
    if (r!=0 && c!=0) 
    {
      //link(3)=(p(2)-p(7))/sqrt(2)
      Vector3f diff=(ori.At(r-1,c)-ori.At(r,c-1))/SQRT2;
      float link=sqrt(diff.LenSqr()/3);
      links_.At(Vector2ui(r,c))[3]=links_.At(Vector2ui(r-1,c-1))[4]=link;
      if (maxlink_<link) maxlink_=link;
    }
    if (c!=0 && r!=0 && r!=ori.Rows()-1) 
    {
      //link(7)=(p(3)/2+p(2)-p(6)/2-p(5))/2
      Vector3f diff=(ori.At(r-1,c-1)/2+ori.At(r-1,c)-ori.At(r+1,c-1)/2-ori.At(r+1,c))/2;
      float link=sqrt(diff.LenSqr()/3);
      links_.At(Vector2ui(r,c))[7]=links_.At(Vector2ui(r,c-1))[0]=link;
      if (maxlink_<link) maxlink_=link;
    }
  }
  // Turn strengths into costs; diagonal links are longer, so they are scaled by sqrt(2).
  for (zuint i=0; i<links_.size(); i++) 
  {
    links_.at(i)[0]=maxlink_-links_.at(i)[0];
    links_.at(i)[1]=(maxlink_-links_.at(i)[1])*SQRT2;
    links_.at(i)[2]=maxlink_-links_.at(i)[2];
    links_.at(i)[3]=(maxlink_-links_.at(i)[3])*SQRT2;
    links_.at(i)[4]=(maxlink_-links_.at(i)[4])*SQRT2;
    links_.at(i)[5]=maxlink_-links_.at(i)[5];
    links_.at(i)[6]=(maxlink_-links_.at(i)[6])*SQRT2;
    links_.at(i)[7]=maxlink_-links_.at(i)[7];
  }
  afterPrepare();
}

//Only links 1, 2, 3 and 7 are computed directly; the opposite links (6, 5, 4 and 0) of the neighbouring pixels receive the same value.
// Builds the per-pixel link-cost table from a 4-channel image.
//
// Same layout as the grayscale overload: links_.At(p)[d] is the cost of stepping from
// p to its neighbour in direction d (3 2 1 / 7 p 0 / 6 5 4), equal to
// (maxlink_ - D) for axial links and (maxlink_ - D)*sqrt(2) for diagonal links.
// Here D = sqrt(|diff|^2 / 3) over all four channels.
// NOTE(review): the divisor is 3, not 4 -- presumably so that e.g. RGBA images with
// constant alpha give the same values as the Image3f overload; TODO confirm. Since it
// only scales every D uniformly, it does not change which path is cheapest.
void EdgeTracer::Prepare(const Image4f &ori)
{
  const float SQRT2=Sqrt(2.0f);
  links_.SetSize(ori.Size());
  memset(links_.Data(),0,sizeof(Vector<8,float>)*links_.size());
  // Every D is >= 0 and uncomputed links stay 0, so 0 is the correct lower bound.
  // Starting at -MAX_FLOAT left maxlink_ there for images with a single row or column
  // (no link is computed), producing negative / infinite costs.
  maxlink_=0.0f;
  for (zuint r=0; r<ori.Rows(); r++) for (zuint c=0; c<ori.Cols(); c++)
  {
    if (r!=0 && c!=ori.Cols()-1) 
    {
      //link(1)=(p(2)-p(0))/sqrt(2)
      Vector4f diff=(ori.At(r-1,c)-ori.At(r,c+1))/SQRT2;
      float link=sqrt(diff.LenSqr()/3);
      links_.At(Vector2ui(r,c))[1]=links_.At(Vector2ui(r-1,c+1))[6]=link;
      if (maxlink_<link) maxlink_=link;
    }
    if (r!=0 && c!=0 && c!=ori.Cols()-1) 
    {
      //link(2)=(p(1)/2+p(0)-p(3)/2-p(7))/2
      Vector4f diff=(ori.At(r-1,c+1)/2+ori.At(r,c+1)-ori.At(r-1,c-1)/2-ori.At(r,c-1))/2;
      float link=sqrt(diff.LenSqr()/3);
      links_.At(Vector2ui(r,c))[2]=links_.At(Vector2ui(r-1,c))[5]=link;
      if (maxlink_<link) maxlink_=link;
    }
    if (r!=0 && c!=0) 
    {
      //link(3)=(p(2)-p(7))/sqrt(2)
      Vector4f diff=(ori.At(r-1,c)-ori.At(r,c-1))/SQRT2;
      float link=sqrt(diff.LenSqr()/3);
      links_.At(Vector2ui(r,c))[3]=links_.At(Vector2ui(r-1,c-1))[4]=link;
      if (maxlink_<link) maxlink_=link;
    }
    if (c!=0 && r!=0 && r!=ori.Rows()-1) 
    {
      //link(7)=(p(3)/2+p(2)-p(6)/2-p(5))/2
      Vector4f diff=(ori.At(r-1,c-1)/2+ori.At(r-1,c)-ori.At(r+1,c-1)/2-ori.At(r+1,c))/2;
      float link=sqrt(diff.LenSqr()/3);
      links_.At(Vector2ui(r,c))[7]=links_.At(Vector2ui(r,c-1))[0]=link;
      if (maxlink_<link) maxlink_=link;
    }
  }
  // Turn strengths into costs; diagonal links are longer, so they are scaled by sqrt(2).
  for (zuint i=0; i<links_.size(); i++) 
  {
    links_.at(i)[0]=maxlink_-links_.at(i)[0];
    links_.at(i)[1]=(maxlink_-links_.at(i)[1])*SQRT2;
    links_.at(i)[2]=maxlink_-links_.at(i)[2];
    links_.at(i)[3]=(maxlink_-links_.at(i)[3])*SQRT2;
    links_.at(i)[4]=(maxlink_-links_.at(i)[4])*SQRT2;
    links_.at(i)[5]=maxlink_-links_.at(i)[5];
    links_.at(i)[6]=(maxlink_-links_.at(i)[6])*SQRT2;
    links_.at(i)[7]=maxlink_-links_.at(i)[7];
  }
  afterPrepare();
}

// Runs (or resumes) Dijkstra's algorithm from start_ until pixel `end` is fixed, then
// follows the stored back-pointers from `end` to start_.
//
// The search state (pathmap + heap) is kept between calls, so later queries from the
// same start only expand pixels that are not fixed yet.
//
// end      : linear pixel index of the end point.
// backpath : receives the linear indices of the path, ordered from `end` back to
//            start_ (both included). Left empty when no start point is set.
void EdgeTracer::FindPath(int end, vector<int>& backpath)
{
  backpath.clear();
  // Without a start point the heap is empty: ExtractMin() would be called on an empty
  // heap and the back-tracking loop could never reach start_ (-1).
  if (start_<0) return;

  //link direction
  //3 2 1
  //7 p 0
  //6 5 4
  const int linelength=links_.Size(1);
  // Linear-index / row / column offsets of the neighbour in each link direction.
  const int offsets[8]={+1,-linelength+1,-linelength,-linelength-1,linelength+1,linelength,linelength-1,-1};
  const int offsetr[8]={0,-1,-1,-1,1,1,1,0};
  const int offsetc[8]={1,1,0,-1,1,0,-1,-1};

  // PathNode::status: 0 = not reached, 1 = in heap, 2 = fixed (final cost known).
  Fibdata fibdata;
  while(pathmap.at(end).status!=2)
  {
    Fibdata thisdata=heap.ExtractMin();
    Vector2ui thisrc=links_.ToIndex(thisdata.pos);
    for (int link=0;link<8;link++)
    {
      Vector2i rc(thisrc);
      rc[0]+=offsetr[link];
      rc[1]+=offsetc[link];
      if (!Within<int>(0,rc[0],links_.Size(0)-1)) continue;  //check if out of boundary
      if (!Within<int>(0,rc[1],links_.Size(1)-1)) continue;  //check if out of boundary
      int otherpos=thisdata.pos+offsets[link];
      PathNode &pathnode=pathmap.at(otherpos);
      if (pathnode.status==2) continue;  //fixed
      // Stepping onto a masked pixel (mask value 1) is 100x cheaper, attracting the path.
      double maskratio=1;
      if (mask_.at(otherpos)==1) maskratio=0.01;
      float thiscost=thisdata.cost+links_.at(thisdata.pos)[link]*maskratio;  //cost at here plus cost to go there
      if (thiscost<pathnode.cost)
      {
        pathnode.cost=thiscost;  //update the minimum cost
        pathnode.backpath=7-link;  //direction pointing back to where the minimum cost comes from

        //update heap
        fibdata.cost=thiscost;
        fibdata.pos=otherpos;
        if (pathnode.status==1)  //already in heap, so need to update it
        {
          heap.DecreaseKey(pathnode.node,fibdata);
        }
        else  //not in the heap yet (cost was still the initial MAX_INT)
        {
          pathnode.node=heap.Insert(fibdata);
          // Record that the node is queued; otherwise every later improvement would
          // insert a duplicate heap entry instead of decreasing this one's key.
          pathnode.status=1;
        }
      }
    }
    pathmap.at(thisdata.pos).node=NULL;
    pathmap.at(thisdata.pos).status=2;
  }

  backpath.push_back(end);
  while(backpath.back()!=start_)
  {
    backpath.push_back(backpath.back() + offsets[pathmap.at(backpath.back()).backpath]);
  }
}

void EdgeTracer::FindPath(const Vector2ui &end, vector<Vector2ui>& backpath)
{
  vector<int> backpathint;
  FindPath(links_.ToIndex(end), backpathint);
  backpath.clear();
  for (zuint i=0; i<backpathint.size(); i++)
    backpath.push_back(links_.ToIndex(backpathint[i]));
}

void EdgeTracer::SetStart(int start)
{
  start_=start;

  pathmap.SetSize(links_.Size());
  PathNode initpathnode;
  initpathnode.node=NULL;
  initpathnode.cost=MAX_INT;
  initpathnode.backpath=-1;
  initpathnode.status=0;
  for (zuint i=0; i<pathmap.size(); i++)
    pathmap.at(i)=initpathnode;

  heap.Clear();
  Fibdata fibdata;
  fibdata.cost=0;
  fibdata.pos=start_;
  pathmap.at(start_).node=heap.Insert(fibdata);
  
  return;
}

// Coordinate-based overload: converts (row, col) to a linear index and delegates.
void EdgeTracer::SetStart(const Vector2ui &start)
{
  const int index=links_.ToIndex(start);
  SetStart(index);
}

int EdgeTracer::OptimizeClick(int start)
{
  Vector2i rc(links_.ToIndex(start));
  int pos=-1;
  float minlink=MAX_FLOAT;
  for (int r=Max<int>(0,rc[0]-optrange_); r<=Min<int>(links_.Size(0)-1,rc[0]+optrange_); r++)
    for (int c=Max<int>(0,rc[1]-optrange_); c<=Min<int>(links_.Size(1)-1,rc[1]+optrange_); c++)
      for (int link=0;link<8;link++)
      {
        int thispos=links_.ToIndex(Vector2ui(r,c));
        if (links_.at(thispos)[link]<minlink)
        {
          minlink=links_.at(thispos)[link];
          pos=thispos;
        }
      }
  return pos;
}

// Coordinate-based overload: converts the click to a linear index, snaps it with
// OptimizeClick(int), and converts the result back to (row, col).
zzz::Vector2ui EdgeTracer::OptimizeClick(const Vector2ui &pos)
{
  const int clicked=links_.ToIndex(pos);
  const int snapped=OptimizeClick(clicked);
  return links_.ToIndex(snapped);
}

void EdgeTracer::SetMask(const Image<zuchar>& mask)
{
  if (mask_.size()==0) return;
  mask_.SetData(mask.Data());  
  pathmap.SetSize(links_.Size());
  PathNode initpathnode;
  initpathnode.node=NULL;
  initpathnode.cost=MAX_INT;
  initpathnode.backpath=-1;
  initpathnode.status=0;
  for (zuint i=0; i<pathmap.size(); i++)
    pathmap.at(i)=initpathnode;

  heap.Clear();
  Fibdata fibdata;
  fibdata.cost=0;
  fibdata.pos=start_;
  pathmap.at(start_).node=heap.Insert(fibdata);
}

void EdgeTracer::ClearMask()
{
  if (mask_.size()==0) return;
  memset(mask_.Data(),0,mask_.size());
  pathmap.SetSize(links_.Size());
  PathNode initpathnode;
  initpathnode.node=NULL;
  initpathnode.cost=MAX_INT;
  initpathnode.backpath=-1;
  initpathnode.status=0;
  for (zuint i=0; i<pathmap.size(); i++)
    pathmap.at(i)=initpathnode;

  heap.Clear();
  Fibdata fibdata;
  fibdata.cost=0;
  fibdata.pos=start_;
  pathmap.at(start_).node=heap.Insert(fibdata);
}

// Shared tail of every Prepare() overload: invalidates state tied to the previous
// link map and allocates a zeroed mask matching the new one.
void EdgeTracer::afterPrepare()
{
  // Any earlier seed refers to the old link map.
  start_=-1;

  // Fresh all-zero mask with the link map's dimensions.
  mask_.SetSize(links_.Size());
  memset(mask_.Data(),0,mask_.size());

  // Drop any pending search frontier.
  heap.Clear();
}
}